Conversation
Walkthrough

The changes introduce asynchronous streaming capabilities to the chat and RAG pipeline layers, allowing chat responses to be delivered incrementally as tokens.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Client
    participant API (query endpoint)
    participant Chat
    participant RAGPipeline
    participant ResponseSynthesizer
    participant LLMModel
    Client->>API (query endpoint): Send chat request (with stream flag)
    API (query endpoint)->>Chat: Create ChatRequest (with stream)
    alt stream == true
        Chat->>RAGPipeline: astream(chat_request)
        RAGPipeline->>ResponseSynthesizer: stream(retrieval_result)
        ResponseSynthesizer->>LLMModel: stream(messages)
        loop For each token
            LLMModel-->>ResponseSynthesizer: yield token
            ResponseSynthesizer-->>RAGPipeline: yield token
            RAGPipeline-->>Chat: yield token
            Chat-->>API (query endpoint): yield token
            API (query endpoint)-->>Client: yield token (SSE)
        end
    else stream == false
        Chat->>RAGPipeline: __acall__(chat_request)
        RAGPipeline->>ResponseSynthesizer: __call__(retrieval_result)
        ResponseSynthesizer->>LLMModel: create(messages)
        LLMModel-->>ResponseSynthesizer: full response
        ResponseSynthesizer-->>RAGPipeline: full response
        RAGPipeline-->>Chat: full response
        Chat-->>API (query endpoint): full response
        API (query endpoint)-->>Client: full response
    end
```
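For orientation, here is a hedged sketch of how a client might consume the streamed tokens. The host, port, and route path are assumptions for illustration; only the `stream` request flag and the `data: <token>` SSE framing come from the changes reviewed below.

```python
# Hypothetical client for the streaming endpoint. The URL is an assumption;
# the request fields and SSE framing are taken from this PR.
import asyncio

import httpx


async def consume_stream() -> None:
    payload = {"question": "How do I log an artifact in W&B?", "stream": True}
    async with httpx.AsyncClient(timeout=None) as client:
        async with client.stream(
            "POST", "http://localhost:8000/chat/query", json=payload
        ) as resp:
            async for line in resp.aiter_lines():
                if line.startswith("data: "):
                    # Each SSE event carries one token yielded by the RAG pipeline.
                    print(line[len("data: "):], end="", flush=True)


if __name__ == "__main__":
    asyncio.run(consume_stream())
```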
Actionable comments posted: 6
🔭 Outside diff range comments (1)
src/wandbot/models/llm.py (1)
439-442: Add error handling to the wrapper stream method. The wrapper method should handle errors consistently with the `create` method.
```diff
 async def stream(self, messages: List[Dict[str, Any]], **kwargs):
-    async for token in self.model.stream(messages=messages, **kwargs):
-        yield token
+    try:
+        async for token in self.model.stream(messages=messages, **kwargs):
+            yield token
+    except Exception as e:
+        logger.error(f"LLMModel streaming error: {str(e)}")
+        # Fall back to non-streaming
+        result, _ = await self.create(messages=messages, **kwargs)
+        if result:
+            yield result
```
🧹 Nitpick comments (2)
src/wandbot/api/routers/chat.py (1)
48-54: Consider adding error handling and content-type headers for a better streaming experience. The streaming implementation is functional but could benefit from enhanced error handling and proper headers.
```diff
 if chat_req.stream:
     async def event_gen():
-        async for token in chat_instance.astream(chat_req):
-            yield f"data: {token}\n\n"
+        try:
+            async for token in chat_instance.astream(chat_req):
+                yield f"data: {token}\n\n"
+        except Exception as e:
+            logger.error(f"Error during streaming: {e}")
+            yield f"data: [ERROR] {str(e)}\n\n"
+        finally:
+            yield "data: [DONE]\n\n"
-    return StreamingResponse(event_gen(), media_type="text/event-stream")
+    return StreamingResponse(
+        event_gen(),
+        media_type="text/event-stream",
+        headers={
+            "Cache-Control": "no-cache",
+            "Connection": "keep-alive",
+            "Access-Control-Allow-Origin": "*",
+        }
+    )
```

src/wandbot/models/llm.py (1)
386-390: Consider implementing true streaming for Google GenAI. The current implementation doesn't actually stream tokens incrementally; it just yields the full result at once, which defeats the purpose of streaming.
Consider investigating whether Google GenAI supports streaming and, if so, implementing it properly, or document why streaming isn't supported:
```diff
 async def stream(self, messages: List[Dict[str, Any]], max_tokens: int = 4000):
+    # TODO: Implement actual streaming when Google GenAI supports it
+    # For now, fall back to full response yielding
     result, _ = await self.create(messages=messages, max_tokens=max_tokens)
     if result:
         yield result
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (7)
- src/wandbot/api/routers/chat.py (2 hunks)
- src/wandbot/chat/chat.py (1 hunks)
- src/wandbot/chat/rag.py (1 hunks)
- src/wandbot/chat/schemas.py (1 hunks)
- src/wandbot/models/llm.py (6 hunks)
- src/wandbot/rag/response_synthesis.py (1 hunks)
- tests/test_stream.py (1 hunks)
🔇 Additional comments (6)
src/wandbot/chat/schemas.py (1)
49-49: LGTM! Clean schema extension for streaming support. The addition of the `stream` field with a sensible default value maintains backward compatibility while enabling the new streaming functionality.
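For context, a minimal sketch of what the extended schema plausibly looks like; the fields other than `stream` are inferred from the `ChatRequest(...)` construction quoted later in this review, and the exact types and defaults in `src/wandbot/chat/schemas.py` may differ.

```python
# Hedged sketch of the extended request schema; actual field types and defaults
# in src/wandbot/chat/schemas.py may differ.
from typing import List, Optional, Tuple

from pydantic import BaseModel


class ChatRequest(BaseModel):
    question: str
    chat_history: Optional[List[Tuple[str, str]]] = None
    application: Optional[str] = None
    language: str = "en"
    stream: bool = False  # new field; defaulting to False keeps existing callers non-streaming
```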
src/wandbot/api/routers/chat.py (2)
2-2: LGTM! Appropriate import for streaming functionality.
40-46: LGTM! Clean ChatRequest construction with streaming support. The updated ChatRequest construction properly includes the new `stream` field from the request payload.

tests/test_stream.py (1)
1-39: Streaming test coverage is solid; the OpenAI dependency is already specified. The `openai` module is declared in both `requirements.txt` and `pyproject.toml`, so no dependency changes are needed. Consider extending the test suite to cover the following (a sketch of the first scenario follows this list):
- Error handling scenarios in the streaming flow
- Integration with the full Chat client rather than a bare `ResponseSynthesizer`
- Language translation or other transformations during streaming
- End-to-end API endpoint streaming behavior
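Below is a self-contained sketch of an error-handling streaming test. It exercises stub models that follow the same primary/fallback streaming contract as `ResponseSynthesizer.stream`, rather than the real class (whose constructor is outside this diff), and it assumes `pytest-asyncio` is installed.

```python
# Sketch only: stub models mirror the primary/fallback contract seen in
# ResponseSynthesizer.stream; the real class is not imported here.
import pytest


class FailingModel:
    model_name = "primary-stub"

    async def stream(self, messages):
        raise RuntimeError("simulated primary failure")
        yield  # unreachable; makes this an async generator


class WorkingModel:
    model_name = "fallback-stub"

    async def stream(self, messages):
        for token in ["hello", " ", "world"]:
            yield token


async def stream_with_fallback(primary, fallback, messages):
    """Mirror the primary-then-fallback loop used in ResponseSynthesizer.stream."""
    try:
        async for token in primary.stream(messages):
            yield token
    except Exception:
        async for token in fallback.stream(messages):
            yield token


@pytest.mark.asyncio  # requires the pytest-asyncio plugin
async def test_falls_back_to_secondary_model_on_streaming_error():
    tokens = [t async for t in stream_with_fallback(FailingModel(), WorkingModel(), messages=[])]
    assert "".join(tokens) == "hello world"
```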
src/wandbot/chat/rag.py (1)
159-159: `stream_output` attribute is correctly defined. After inspecting `src/wandbot/rag/response_synthesis.py`, the `ResponseSynthesizer` class includes `self.stream_output = {"query_str": …, "context_str": …, "response": result}`. Furthermore, `tests/test_stream.py` validates its usage (`assert synth.stream_output["response"] == "hello world"`). No changes needed.
115-119: LGTM! Good fallback implementation. The base class stream method provides a reasonable fallback by calling the existing create method and yielding the result.
```python
async def astream(self, chat_request: ChatRequest):
    """Stream the chat response tokens asynchronously."""
    original_language = chat_request.language

    working_request = chat_request

    if original_language == "ja":
        translated_question = translate_ja_to_en(
            chat_request.question, self.chat_config.ja_translation_model_name
        )
        working_request = ChatRequest(
            question=translated_question,
            chat_history=chat_request.chat_history,
            application=chat_request.application,
            language="en",
        )

    async for token in self.rag_pipeline.astream(
        working_request.question, working_request.chat_history or []
    ):
        yield token

    result = self.rag_pipeline.stream_result
    result_dict = result.model_dump()

    if original_language == "ja":
        result_dict["answer"] = translate_en_to_ja(
            result_dict["answer"], self.chat_config.ja_translation_model_name
        )

    result_dict.update({"application": chat_request.application})
    self.last_stream_response = ChatResponse(**result_dict)
```
🛠️ Refactor suggestion
Enhance error handling and consistency with the main `__acall__` method.
The streaming implementation lacks the comprehensive error handling, timing, and API status tracking present in `__acall__`. This could result in inconsistent behavior and poor error visibility.
Consider these improvements:
- Add proper exception handling with `ErrorInfo` tracking
- Include timing information using `Timer`
- Add API status tracking similar to `__acall__`
- Handle translation errors gracefully with fallback responses
```diff
 async def astream(self, chat_request: ChatRequest):
     """Stream the chat response tokens asynchronously."""
     original_language = chat_request.language
+    api_call_statuses = {}
     working_request = chat_request
+
+    with Timer() as timer:
+        try:
+            # Handle Japanese translation with error handling
+            if original_language == "ja":
+                try:
+                    translated_question = translate_ja_to_en(
+                        chat_request.question, self.chat_config.ja_translation_model_name
+                    )
+                    working_request = ChatRequest(
+                        question=translated_question,
+                        chat_history=chat_request.chat_history,
+                        application=chat_request.application,
+                        language="en",
+                    )
+                except Exception as e:
+                    # Handle translation error similar to __acall__
+                    api_call_statuses["chat_success"] = False
+                    api_call_statuses["chat_error_info"] = ErrorInfo(
+                        has_error=True, error_message=str(e),
+                        error_type=type(e).__name__, component="translation"
+                    ).model_dump()
+                    yield f"Translation error: {str(e)}"
+                    return
+
+            async for token in self.rag_pipeline.astream(
+                working_request.question, working_request.chat_history or []
+            ):
+                yield token
+
+        except Exception as e:
+            # Handle streaming errors
+            api_call_statuses["chat_success"] = False
+            api_call_statuses["chat_error_info"] = ErrorInfo(
+                has_error=True, error_message=str(e),
+                error_type=type(e).__name__, component="chat"
+            ).model_dump()
+            yield f"Streaming error: {str(e)}"
+            return
+
+    # Store final result with complete metadata
+    result = self.rag_pipeline.stream_result
+    result_dict = result.model_dump()
+
+    # Handle response translation
+    if original_language == "ja":
+        try:
+            result_dict["answer"] = translate_en_to_ja(
+                result_dict["answer"], self.chat_config.ja_translation_model_name
+            )
+        except Exception as e:
+            result_dict["answer"] = f"Translation error: {str(e)}\nOriginal answer: {result_dict['answer']}"
+
+    # Update with complete metadata
+    api_call_statuses["chat_success"] = True
+    result_dict.update({
+        "application": chat_request.application,
+        "api_call_statuses": api_call_statuses,
+        "time_taken": timer.elapsed,
+        "start_time": timer.start,
+        "end_time": timer.stop,
+    })
+
+    self.last_stream_response = ChatResponse(**result_dict)
```

Committable suggestion skipped: line range outside the PR's diff.
🤖 Prompt for AI Agents
In src/wandbot/chat/chat.py lines 237 to 268, the astream method lacks the
robust error handling, timing, and API status tracking found in the __acall__
method, leading to inconsistent behavior and poor error visibility. To fix this,
wrap the streaming logic in try-except blocks to catch exceptions and record
them using ErrorInfo, use a Timer to measure execution time, and update API
status accordingly. Also, handle translation errors gracefully by providing
fallback responses instead of failing outright, ensuring consistent and reliable
streaming behavior.
```python
async def stream(self, inputs: RetrievalResult):
    """Stream response tokens while capturing the final result."""
    formatted_input = self._format_input(inputs)
    messages = self.get_messages(formatted_input)

    result = ""
    used_model = self.model
    try:
        async for token in self.model.stream(messages=messages):
            result += token
            yield token
    except Exception as e:
        logger.warning(f"Primary Response Synthesizer model failed, trying fallback: {str(e)}")
        used_model = self.fallback_model
        async for token in self.fallback_model.stream(messages=messages):
            result += token
            yield token

    self.stream_output = {
        "query_str": formatted_input["query_str"],
        "context_str": formatted_input["context_str"],
        "response": result,
        "response_model": used_model.model_name,
        "response_synthesis_llm_messages": messages,
        "response_prompt": RESPONSE_SYNTHESIS_SYSTEM_PROMPT,
        "api_statuses": {
            "response_synthesis_llm_api": None
        },
    }
```
🛠️ Refactor suggestion
Improve error handling and API status tracking in the streaming method.
The streaming implementation has several issues that could affect reliability and debugging:
- API status is set to `None` instead of proper tracking
- Fallback error handling doesn't reset the `result` accumulator
- No exception handling for the fallback model failure
- Missing comprehensive error information
```diff
 async def stream(self, inputs: RetrievalResult):
     """Stream response tokens while capturing the final result."""
     formatted_input = self._format_input(inputs)
     messages = self.get_messages(formatted_input)
     result = ""
     used_model = self.model
+    llm_api_status = None
+
     try:
         async for token in self.model.stream(messages=messages):
             result += token
             yield token
+        llm_api_status = APIStatus(success=True, error_info=None)
     except Exception as e:
         logger.warning(f"Primary Response Synthesizer model failed, trying fallback: {str(e)}")
         used_model = self.fallback_model
-        async for token in self.fallback_model.stream(messages=messages):
-            result += token
-            yield token
+        result = ""  # Reset accumulator for fallback
+        try:
+            async for token in self.fallback_model.stream(messages=messages):
+                result += token
+                yield token
+            llm_api_status = APIStatus(success=True, error_info=None)
+        except Exception as fallback_e:
+            logger.error(f"Both primary and fallback Response Synthesizer models failed: {str(fallback_e)}")
+            llm_api_status = APIStatus(success=False, error_info=str(fallback_e))
+            raise Exception(f"Response synthesis streaming failed: {str(fallback_e)}")
     self.stream_output = {
         "query_str": formatted_input["query_str"],
         "context_str": formatted_input["context_str"],
         "response": result,
         "response_model": used_model.model_name,
         "response_synthesis_llm_messages": messages,
         "response_prompt": RESPONSE_SYNTHESIS_SYSTEM_PROMPT,
         "api_statuses": {
-            "response_synthesis_llm_api": None
+            "response_synthesis_llm_api": llm_api_status
         },
     }
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
async def stream(self, inputs: RetrievalResult):
    """Stream response tokens while capturing the final result."""
    formatted_input = self._format_input(inputs)
    messages = self.get_messages(formatted_input)
    result = ""
    used_model = self.model
    llm_api_status = None
    try:
        async for token in self.model.stream(messages=messages):
            result += token
            yield token
        llm_api_status = APIStatus(success=True, error_info=None)
    except Exception as e:
        logger.warning(f"Primary Response Synthesizer model failed, trying fallback: {str(e)}")
        used_model = self.fallback_model
        result = ""  # Reset accumulator for fallback
        try:
            async for token in self.fallback_model.stream(messages=messages):
                result += token
                yield token
            llm_api_status = APIStatus(success=True, error_info=None)
        except Exception as fallback_e:
            logger.error(f"Both primary and fallback Response Synthesizer models failed: {str(fallback_e)}")
            llm_api_status = APIStatus(success=False, error_info=str(fallback_e))
            raise Exception(f"Response synthesis streaming failed: {str(fallback_e)}")
    self.stream_output = {
        "query_str": formatted_input["query_str"],
        "context_str": formatted_input["context_str"],
        "response": result,
        "response_model": used_model.model_name,
        "response_synthesis_llm_messages": messages,
        "response_prompt": RESPONSE_SYNTHESIS_SYSTEM_PROMPT,
        "api_statuses": {
            "response_synthesis_llm_api": llm_api_status
        },
    }
```
🤖 Prompt for AI Agents
In src/wandbot/rag/response_synthesis.py lines 230 to 258, improve the stream
method by properly tracking API call statuses instead of setting them to None,
reset the result accumulator before using the fallback model to avoid mixing
outputs, add exception handling around the fallback model streaming to catch and
log any errors, and include detailed error information in the logs to aid
debugging and reliability.
```python
self.stream_result = RAGPipelineOutput(
    question=enhanced_query["standalone_query"],
    answer=response["response"],
    sources="\n".join(
        [doc.metadata["source"] for doc in retrieval_result.documents]
    ),
    source_documents=response["context_str"],
    system_prompt=response["response_prompt"],
    model=response["response_model"],
    total_tokens=0,
    prompt_tokens=0,
    completion_tokens=0,
    time_taken=0,
    start_time=datetime.datetime.now(),
    end_time=datetime.datetime.now(),
    api_call_statuses={
        "web_search_success": retrieval_result.retrieval_info["api_statuses"]["web_search_api"].success,
        "reranker_api_error_info": retrieval_result.retrieval_info["api_statuses"]["reranker_api"].error_info,
        "reranker_api_success": retrieval_result.retrieval_info["api_statuses"]["reranker_api"].success,
        "query_enhancer_llm_api_error_info": enhanced_query.get("api_statuses", {}).get("query_enhancer_llm_api", {}).error_info if enhanced_query.get("api_statuses") else None,
        "query_enhancer_llm_api_success": enhanced_query.get("api_statuses", {}).get("query_enhancer_llm_api", {}).success if enhanced_query.get("api_statuses") else False,
        "embedding_api_error_info": retrieval_result.retrieval_info["api_statuses"]["embedding_api"].error_info,
        "embedding_api_success": retrieval_result.retrieval_info["api_statuses"]["embedding_api"].success,
    },
    response_synthesis_llm_messages=response.get("response_synthesis_llm_messages"),
)
```
🛠️ Refactor suggestion
Eliminate code duplication and improve error handling.
The `RAGPipelineOutput` construction logic is duplicated between the `__acall__` and `astream` methods. Consider extracting this into a helper method and adding proper error handling.
```diff
+def _build_pipeline_output(self, enhanced_query, response, retrieval_result) -> RAGPipelineOutput:
+    """Helper method to build RAGPipelineOutput consistently."""
+    return RAGPipelineOutput(
+        question=enhanced_query["standalone_query"],
+        answer=response["response"],
+        sources="\n".join(
+            [doc.metadata["source"] for doc in retrieval_result.documents]
+        ),
+        source_documents=response["context_str"],
+        system_prompt=response["response_prompt"],
+        model=response["response_model"],
+        total_tokens=0,
+        prompt_tokens=0,
+        completion_tokens=0,
+        time_taken=0,
+        start_time=datetime.datetime.now(),
+        end_time=datetime.datetime.now(),
+        api_call_statuses={
+            "web_search_success": retrieval_result.retrieval_info["api_statuses"]["web_search_api"].success,
+            "reranker_api_error_info": retrieval_result.retrieval_info["api_statuses"]["reranker_api"].error_info,
+            "reranker_api_success": retrieval_result.retrieval_info["api_statuses"]["reranker_api"].success,
+            "query_enhancer_llm_api_error_info": enhanced_query.get("api_statuses", {}).get("query_enhancer_llm_api", {}).error_info if enhanced_query.get("api_statuses") else None,
+            "query_enhancer_llm_api_success": enhanced_query.get("api_statuses", {}).get("query_enhancer_llm_api", {}).success if enhanced_query.get("api_statuses") else False,
+            "embedding_api_error_info": retrieval_result.retrieval_info["api_statuses"]["embedding_api"].error_info,
+            "embedding_api_success": retrieval_result.retrieval_info["api_statuses"]["embedding_api"].success,
+        },
+        response_synthesis_llm_messages=response.get("response_synthesis_llm_messages"),
+    )
```

Committable suggestion skipped: line range outside the PR's diff.
🤖 Prompt for AI Agents
In src/wandbot/chat/rag.py around lines 161 to 186, the construction of the
RAGPipelineOutput object is duplicated in both __acall__ and astream methods. To
fix this, extract the RAGPipelineOutput creation logic into a separate helper
method that takes the necessary inputs and returns the constructed object. Then
replace the duplicated code in both methods with calls to this helper.
Additionally, add proper error handling within this helper to catch and manage
any exceptions during the construction process.
```python
async def astream(
    self, question: str, chat_history: List[Tuple[str, str]] | None = None
) -> None:
```
Fix the return type annotation for the async generator.
The method yields tokens but is annotated as returning None. This should be an async generator type.
```diff
-async def astream(
-    self, question: str, chat_history: List[Tuple[str, str]] | None = None
-) -> None:
+async def astream(
+    self, question: str, chat_history: List[Tuple[str, str]] | None = None
+) -> AsyncGenerator[str, None]:
```

You'll also need to add the import:

```diff
-from typing import Dict, List, Tuple
+from typing import AsyncGenerator, Dict, List, Tuple
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
from typing import AsyncGenerator, Dict, List, Tuple

async def astream(
    self, question: str, chat_history: List[Tuple[str, str]] | None = None
) -> AsyncGenerator[str, None]:
```
🤖 Prompt for AI Agents
In src/wandbot/chat/rag.py around lines 146 to 148, the async method astream is
currently annotated as returning None, but it actually yields tokens as an async
generator. Change the return type annotation to an appropriate async generator
type that reflects the yielded token type. Also, add the necessary import for
the async generator type from the typing module to support this annotation.
```python
async def stream(self, messages: List[Dict[str, Any]]):
    api_params = {
        "model": self.model_name,
        "temperature": self.temperature,
        "messages": messages,
        "stream": True,
    }
    if api_params["temperature"] == 0:
        api_params["temperature"] = 0.1

    if self.model_name.startswith("o"):
        api_params.pop("temperature", None)

    async with self.semaphore:
        response = await self.client.chat.completions.create(**api_params)
        async for chunk in response:
            delta = chunk.choices[0].delta.content
            if delta:
                yield delta
```
Add response_model support and error handling to OpenAI streaming.
The streaming implementation doesn't handle the response_model case and lacks error handling compared to the create method.
```diff
 async def stream(self, messages: List[Dict[str, Any]]):
+    if self.response_model:
+        # For structured outputs, fall back to non-streaming
+        result, _ = await self.create(messages=messages)
+        if result:
+            yield result
+        return
+
+    try:
     api_params = {
         "model": self.model_name,
         "temperature": self.temperature,
         "messages": messages,
         "stream": True,
     }
     if api_params["temperature"] == 0:
         api_params["temperature"] = 0.1
     if self.model_name.startswith("o"):
         api_params.pop("temperature", None)
     async with self.semaphore:
         response = await self.client.chat.completions.create(**api_params)
         async for chunk in response:
             delta = chunk.choices[0].delta.content
             if delta:
                 yield delta
+    except Exception as e:
+        logger.error(f"OpenAI streaming error: {str(e)}")
+        # Fall back to non-streaming
+        result, _ = await self.create(messages=messages)
+        if result:
+            yield result
```
+ yield result📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
async def stream(self, messages: List[Dict[str, Any]]):
    # If a response_model is configured, fall back to non-streaming
    if self.response_model:
        result, _ = await self.create(messages=messages)
        if result:
            yield result
        return
    try:
        api_params = {
            "model": self.model_name,
            "temperature": self.temperature,
            "messages": messages,
            "stream": True,
        }
        if api_params["temperature"] == 0:
            api_params["temperature"] = 0.1
        if self.model_name.startswith("o"):
            api_params.pop("temperature", None)
        async with self.semaphore:
            response = await self.client.chat.completions.create(**api_params)
            async for chunk in response:
                delta = chunk.choices[0].delta.content
                if delta:
                    yield delta
    except Exception as e:
        logger.error(f"OpenAI streaming error: {e}")
        # Fall back to non-streaming on error
        result, _ = await self.create(messages=messages)
        if result:
            yield result
```
🤖 Prompt for AI Agents
In src/wandbot/models/llm.py around lines 188 to 207, the stream method lacks
support for the response_model parameter and does not include error handling
like the create method. Update the stream method to accept and pass the
response_model argument when calling the OpenAI API, and add appropriate
try-except blocks to catch and handle exceptions during the streaming process,
ensuring errors are logged or managed gracefully.
```python
async def stream(self, messages: List[Dict[str, Any]], max_tokens: int = 4000):
    processed_messages = []
    for msg in messages:
        if msg.get("role") == "developer":
            processed_messages.append({"role": "system", "content": msg.get("content")})
        else:
            processed_messages.append(msg)

    system_msg, chat_messages = extract_system_and_messages(processed_messages)
    api_params = {
        "model": self.model_name,
        "temperature": self.temperature,
        "messages": chat_messages,
        "max_tokens": max_tokens,
        "stream": True,
    }

    if api_params["temperature"] == 0:
        api_params["temperature"] = 0.1

    if system_msg:
        api_params["system"] = system_msg

    if self.response_model:
        api_params["messages"] += add_json_response_model_to_messages(self.response_model)

    async with self.semaphore:
        response = await self.client.messages.create(**api_params)
        async for chunk in response:
            try:
                delta = chunk.delta.text  # type: ignore
            except AttributeError:
                delta = ""
            if delta:
                yield delta
```
🛠️ Refactor suggestion
Eliminate code duplication and improve error handling.
The message preprocessing logic is duplicated from the create method. Consider extracting this into a helper method and adding proper error handling.
```diff
+def _preprocess_messages(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+    """Helper method to preprocess messages for Anthropic API."""
+    processed_messages = []
+    for msg in messages:
+        if msg.get("role") == "developer":
+            processed_messages.append({"role": "system", "content": msg.get("content")})
+            logger.debug("Converted 'developer' role to 'system' for Anthropic call.")
+        else:
+            processed_messages.append(msg)
+    return processed_messages
+
 async def stream(self, messages: List[Dict[str, Any]], max_tokens: int = 4000):
-    processed_messages = []
-    for msg in messages:
-        if msg.get("role") == "developer":
-            processed_messages.append({"role": "system", "content": msg.get("content")})
-        else:
-            processed_messages.append(msg)
+    try:
+        processed_messages = self._preprocess_messages(messages)
+        # ... rest of the implementation
+    except Exception as e:
+        logger.error(f"Anthropic streaming error: {str(e)}")
+        # Fall back to non-streaming
+        result, _ = await self.create(messages=messages, max_tokens=max_tokens)
+        if result:
+            yield result
```

Committable suggestion skipped: line range outside the PR's diff.
🤖 Prompt for AI Agents
In src/wandbot/models/llm.py between lines 272 and 307, the message
preprocessing logic duplicates code from the create method and lacks robust
error handling. Refactor by extracting the message preprocessing into a separate
helper method that both stream and create methods can call. Add try-except
blocks around the preprocessing steps to catch and handle potential errors
gracefully, ensuring the stream method yields results only when preprocessing
succeeds.
Summary
- `astream` pipeline to LLM streaming

Testing

- `pytest -k test_streaming_response tests/test_stream.py` (fails: `ModuleNotFoundError: No module named 'openai'`)

https://chatgpt.com/codex/tasks/task_e_686b7f40cdf4832ba2d6ca8a1b3a7570
Summary by CodeRabbit
New Features
Bug Fixes
Tests